{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import os\n",
    "import numpy as np\n",
    "import tqdm\n",
    "\n",
    "# Every edited run appends one record to this file. Refuse to start when it\n",
    "# already exists so records from separate experiments are never mixed.\n",
    "fout_path = 'bleus_editable.txt'\n",
    "if os.path.isfile(fout_path):\n",
    "    raise ValueError(\"File already exists\")\n",
    "\n",
    "def get_edit_bleu(checkpoint_path, edit_sample_index=None):\n",
    "    \"\"\"Run edited_generate.py on a checkpoint and parse the metrics it prints.\n",
    "\n",
    "    Args:\n",
    "        checkpoint_path: checkpoint file passed to the script via --path.\n",
    "        edit_sample_index: optional index forwarded as --edit-sample-index;\n",
    "            when None the flag is omitted from the command line.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (bleu, success, complexity). bleu is a float, or None when no\n",
    "        BLEU score could be parsed from the expected output line. success\n",
    "        (bool) and complexity (int) are parsed from the EditResult line; both\n",
    "        are None when edit_sample_index is None or that line did not match.\n",
    "\n",
    "    Side effects:\n",
    "        When edit_sample_index is given, appends one line of the form\n",
    "        'bleu success complexity' to fout_path.\n",
    "    \"\"\"\n",
    "    edit_sample_str = ''\n",
    "    if edit_sample_index is not None:\n",
    "        edit_sample_str = '--edit-sample-index {}'.format(edit_sample_index)\n",
    "\n",
    "    res = !python edited_generate.py data-bin/iwslt14.tokenized.de-en  \\\n",
    "           --path {checkpoint_path} \\\n",
    "           --edit-samples-path edit_iwslt14.tokenized.de-en/bpe_test.txt \\\n",
    "           {edit_sample_str} \\\n",
    "           --no-progress-bar \\\n",
    "           --beam 5 --remove-bpe --sacrebleu --moses-detokenizer de \n",
    "        \n",
    "    # With an edit sample the EditResult is read from the last output line,\n",
    "    # so the BLEU line is expected one line earlier.\n",
    "    bleu_id = -2 if edit_sample_index is not None else -1\n",
    "\n",
    "    try:\n",
    "        bleu = float(re.findall(r'BLEU\\(score=(\\d+(?:\\.\\d+)?),', res[bleu_id])[0])\n",
    "    except (IndexError, ValueError):\n",
    "        # Output too short, or no BLEU score on the expected line.\n",
    "        bleu = None\n",
    "\n",
    "    success = complexity = None\n",
    "    if edit_sample_index is not None:\n",
    "        try:\n",
    "            success_str, complexity_str = re.findall(\n",
    "                r'EditResult\\(success=(True|False), complexity=(\\d+)\\)', res[-1])[0]\n",
    "            # The pattern only admits 'True' or 'False', so a comparison\n",
    "            # replaces eval() on text produced by a subprocess.\n",
    "            success = success_str == 'True'\n",
    "            complexity = int(complexity_str)\n",
    "        except (IndexError, ValueError):\n",
    "            # No EditResult on the last line: record the run as unparsed.\n",
    "            success = complexity = None\n",
    "\n",
    "        with open(fout_path, 'a') as f:\n",
    "            print(bleu, success, complexity, file=f)\n",
    "\n",
    "    return bleu, success, complexity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run without an edit sample: only BLEU is parsed, success and complexity come back as None.\n",
    "get_edit_bleu(checkpoint_path='checkpoints_editable/checkpoint_best.pt', edit_sample_index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so the same sample indices are drawn on every run.\n",
    "np.random.seed(42)\n",
    "# Keep the first 1000 entries of a random permutation of range(6749).\n",
    "ids = np.random.permutation(6749)[:1000]\n",
    "\n",
    "# One edited run per sampled index; the list of returned tuples is the cell output.\n",
    "[\n",
    "    get_edit_bleu('checkpoints_editable/checkpoint_best.pt', edit_sample_index=sample_index)\n",
    "    for sample_index in tqdm.tqdm_notebook(ids)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bleus = []\n",
    "successes = []\n",
    "complexities = []\n",
    "# NOTE(review): this is not the file named by fout_path ('bleus_editable.txt');\n",
    "# confirm it is the intended results file before trusting the numbers below.\n",
    "with open('bleus_samples_stability1_kl_almost_last.txt') as f:\n",
    "    for line in f:\n",
    "        # split() already discards the trailing newline; slicing with [:-1]\n",
    "        # would drop a real character when the last line has no newline.\n",
    "        fields = line.split()\n",
    "        if not fields:\n",
    "            # Blank line, not a record.\n",
    "            continue\n",
    "\n",
    "        try:\n",
    "            bleu, success, complexity = fields\n",
    "            if success not in ('True', 'False'):\n",
    "                raise ValueError('unparsed success flag: {}'.format(success))\n",
    "            bleu = float(bleu)\n",
    "            # bool('False') is True (any non-empty string is truthy), so the\n",
    "            # flag has to be compared against its textual form.\n",
    "            success = success == 'True'\n",
    "            complexity = float(complexity)\n",
    "        except ValueError:\n",
    "            # Records with 'None' fields or a wrong field count are counted\n",
    "            # as failed edits with placeholder values.\n",
    "            bleu = 0\n",
    "            success = False\n",
    "            complexity = -1\n",
    "\n",
    "        bleus.append(bleu)\n",
    "        successes.append(success)\n",
    "        complexities.append(complexity)\n",
    "\n",
    "bleus = np.array(bleus)\n",
    "successes = np.array(successes)\n",
    "complexities = np.array(complexities)\n",
    "\n",
    "# The means include the 0 / -1 placeholders of records that failed to parse.\n",
    "print('Mean BLEU:', np.mean(bleus))\n",
    "print('Success rate:', np.mean(successes))\n",
    "print('Mean complexity:', np.mean(complexities))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
